{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ONNX Script `optimizer`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import tempfile\n",
    "import numpy as np\n",
    "import onnx\n",
    "import onnxruntime\n",
    "\n",
    "from onnxscript import optimizer\n",
    "from onnxscript.utils import evaluation_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the folder as a Path: a plain str has no .iterdir(), so the listing\n",
    "# below would raise AttributeError on a fresh kernel.\n",
    "model_folder_path = Path(\"./temp/e2e_models\")\n",
    "# Names of the sub-directories directly under the model folder, sorted so\n",
    "# the models are processed in a deterministic order.\n",
    "model_names = sorted(\n",
    "    entry.name for entry in model_folder_path.iterdir() if entry.is_dir()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skip storing constant folded nvalue self_4 due to large size 1050624.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_4 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_5 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_10 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_11 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_16 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_17 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_22 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_23 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_28 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_29 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_34 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_35 due to large size 2097152.\n",
      "Skip storing constant folded nvalue result_1 due to large size 10240000.\n",
      "Skip storing constant folded nvalue t_36 due to large size 10240000.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(value_info): 7144\n",
      "Applied 44 of general pattern rewrite rules.\n",
      "len(value_info): 860\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skip storing constant folded nvalue t_4 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_5 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_10 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_11 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_16 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_17 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_22 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_23 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_28 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_29 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_34 due to large size 2097152.\n",
      "Skip storing constant folded nvalue t_35 due to large size 2097152.\n",
      "Skip storing constant folded nvalue torch_nn_modules_linear_Linear_lm_head_1_8_t_36 due to large size 10240000.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(value_info): 4218\n",
      "Applied 0 of general pattern rewrite rules.\n",
      "len(value_info): 768\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skip storing constant folded nvalue result_1 due to large size 5120000.\n",
      "Skip storing constant folded nvalue torch_nn_modules_linear_Linear_classifier_1_6_t due to large size 5120000.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(value_info): 23992\n",
      "Applied 0 of general pattern rewrite rules.\n",
      "len(value_info): 2069\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skip storing constant folded nvalue torch_nn_modules_linear_Linear_classifier_1_6_t due to large size 5120000.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(value_info): 10935\n",
      "Applied 0 of general pattern rewrite rules.\n",
      "len(value_info): 1755\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skip storing constant folded nvalue result_1 due to large size 2048000.\n",
      "Skip storing constant folded nvalue torch_nn_modules_linear_Linear_fc_1_12_t due to large size 2048000.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(value_info): 5182\n",
      "Applied 0 of general pattern rewrite rules.\n",
      "len(value_info): 480\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skip storing constant folded nvalue torch_nn_modules_linear_Linear_fc_1_12_t due to large size 2048000.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(value_info): 2584\n",
      "Applied 0 of general pattern rewrite rules.\n",
      "len(value_info): 217\n"
     ]
    }
   ],
   "source": [
    "# For each model: load the dynamo export, optimize it, run the optimized\n",
    "# model with ONNX Runtime, and compare against the stored expected outputs.\n",
    "for model_name in model_names:\n",
    "    if model_name == \"torchscript_model\":\n",
    "        continue\n",
    "    model_dir = Path(model_folder_path) / model_name / \"dynamo\"\n",
    "    model_path = model_dir / f\"{model_name}_dynamo.onnx\"\n",
    "    model = onnx.load(model_path)\n",
    "    model = optimizer.optimize(model, onnx_shape_inference=False)\n",
    "    # The optimized model only needs to live long enough to be run once, so\n",
    "    # it is written into a temporary directory that is removed afterwards.\n",
    "    with tempfile.TemporaryDirectory() as tmp_folder:\n",
    "        optimized_model_path = Path(tmp_folder) / f\"{model_name}_opt.onnx\"\n",
    "        onnx.save(\n",
    "            model,\n",
    "            optimized_model_path,\n",
    "            save_as_external_data=True,\n",
    "            all_tensors_to_one_file=True,\n",
    "        )\n",
    "\n",
    "        session = onnxruntime.InferenceSession(\n",
    "            optimized_model_path, providers=(\"CPUExecutionProvider\",)\n",
    "        )\n",
    "\n",
    "        inputs, expected_outputs = evaluation_utils.load_test_data(\n",
    "            model_dir, [i.name for i in model.graph.input]\n",
    "        )\n",
    "\n",
    "        # The loaded test inputs must match the session's inputs exactly.\n",
    "        input_names = [i.name for i in session.get_inputs()]\n",
    "        assert set(input_names) == set(inputs.keys())\n",
    "\n",
    "        outputs = session.run(None, inputs)\n",
    "        # Free the session so the model file is no longer used\n",
    "        del session\n",
    "\n",
    "        # zip() stops at the shorter sequence, so check the counts first;\n",
    "        # otherwise missing or extra outputs would go unverified.\n",
    "        assert len(outputs) == len(expected_outputs), (\n",
    "            f\"{model_name}: got {len(outputs)} outputs, \"\n",
    "            f\"expected {len(expected_outputs)}\"\n",
    "        )\n",
    "        for output, expected_output in zip(outputs, expected_outputs):\n",
    "            np.testing.assert_allclose(output, expected_output, rtol=1e-3, atol=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
